Skip to content

[BE][export] add data-dependent section to export tutorial #3244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 23, 2025

Conversation

pianpwk
Copy link
Contributor

@pianpwk pianpwk commented Jan 22, 2025

Adds brief section, with pointer to more in-depth data-dependent errors doc

Checklist

  • The issue that is being fixed is referred in the description (see above "Fixes #ISSUE_NUMBER")
  • Only one issue is addressed in this pull request
  • Labels from the issue that this PR is fixing are added to this pull request
  • No unnecessary issues are included into this pull request.

cc @williamwen42 @msaroufim @anijain2305

Copy link

pytorch-bot bot commented Jan 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/3244

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit db1b6a6 with merge base 5a5edfc (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Contributor

@svekars svekars left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a couple of editorial nits.

# While trying to export models, you have may have encountered errors like ``Could not guard on data-dependent expression`` or ``Could not extract specialized integer from data-dependent expression``.
# Obscure as they may seem, the reasoning behind their existence, and their resolution, is actually quite straightforward.
#
# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (e.g. they may have the same or equivalent symbolic properties
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (e.g. they may have the same or equivalent symbolic properties
# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (for example, they may have the same or equivalent symbolic properties

# Obscure as they may seem, the reasoning behind their existence, and their resolution, is actually quite straightforward.
#
# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (e.g. they may have the same or equivalent symbolic properties
# - sizes, strides, dtypes, etc.), but diverge in one major respect: they do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that the compiler may
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# - sizes, strides, dtypes, etc.), but diverge in one major respect: they do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that the compiler may
# - sizes, strides, dtypes, and so on), but diverge in one major respect: they do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that the compiler may

# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", in contrast to the "backed" symbols/SymInts
# allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence, or absence, of a "hint" for the symbol: a concrete value backing the symbol, that can inform the compiler how to proceed.
#
# For dynamic input shapes (backed SymInts), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# For dynamic input shapes (backed SymInts), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching.
# For dynamic input shapes (backed ``SymInts``), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching.

#
# Let's talk about where data-dependent values appear in programs. Common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors.
# How are these values represented in the exported program? In the ``Constraints/Dynamic Shapes`` section, we talked about allocating symbols to represent dynamic input dimensions, and the same happens here -
# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", in contrast to the "backed" symbols/SymInts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", in contrast to the "backed" symbols/SymInts
# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked" ``SymInts``, in contrast to the "backed" symbols/``SymInts``

# allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence, or absence, of a "hint" for the symbol: a concrete value backing the symbol, that can inform the compiler how to proceed.
#
# For dynamic input shapes (backed SymInts), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching.
# On the other hand, data-dependent values are derived from FakeTensors during tracing, and by default lack hints to inform the compiler, hence the name "unbacked symbols/SymInts".
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# On the other hand, data-dependent values are derived from FakeTensors during tracing, and by default lack hints to inform the compiler, hence the name "unbacked symbols/SymInts".
# On the other hand, data-dependent values are derived from ``FakeTensors`` during tracing, and by default lack hints to inform the compiler, hence the name "unbacked symbols" or ``SymInts``.

@svekars svekars added the torch.compile Torch compile and other relevant tutorials label Jan 22, 2025
@pianpwk pianpwk changed the title init test branch [export] add data-dependent section to export tutorial Jan 23, 2025
@pianpwk pianpwk marked this pull request as ready for review January 23, 2025 00:17
@pianpwk pianpwk changed the title [export] add data-dependent section to export tutorial [BE][export] add data-dependent section to export tutorial Jan 23, 2025
Copy link

@yushangdi yushangdi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Thanks for the nice tutorial!

Copy link

@avikchaudhuri avikchaudhuri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please address nits

# ---------------------
#
# While trying to export models, you have may have encountered errors like "Could not guard on data-dependent expression", or "Could not extract specialized integer from data-dependent expression".
# These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. For example, they may have equivalent symbolic properties

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cut "For example, ...may"

You can rephrase by highlighting the main difference first, or use "While they have equivalent..., they diverge in that..."

#
# While trying to export models, you have may have encountered errors like "Could not guard on data-dependent expression", or "Could not extract specialized integer from data-dependent expression".
# These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. For example, they may have equivalent symbolic properties
# (e.g. sizes, strides, dtypes), but diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may struggle

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reword "may struggle"

# These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. For example, they may have equivalent symbolic properties
# (e.g. sizes, strides, dtypes), but diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may struggle
# with parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that
# FakeTensor tracing isn't providing the information required.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally speaking, you should highlight data-dependence as the main cause rather than making it about fake tensors. Also, export programming model has sections on shape vs. data dependence.

# How are these values represented in the exported program? In the `Constraints/Dynamic Shapes <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#constraints-dynamic-shapes>`_
# section, we talked about allocating symbols to represent dynamic input dimensions.
# The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts",
# in contrast to the "backed" symbols/SymInts allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence/absence of a "hint" for the symbol:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

symbols / SymInts vs. symbols or SymInts, choose one

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, the export programming model explains what backed vs. unbacked means, so maybe link there.

# The result is that 3 unbacked symbols (notice they're prefixed with "u", instead of the usual "s" for input shape/backed symbols) are allocated and returned:
# 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call.
# Note from the range constraints field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols,
# since we literally have no information on what these values are - they don't represent sizes, so don't necessarily have positive values.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cut "literally"

# Here we actually need the "hint", or the concrete value of ``a`` for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5`` as the output.
# Because we trace with FakeTensors, we don't know what ``a // 2 >= 5`` actually evaluates to, and export errors out with "Could not guard on data-dependent expression ``u0 // 2 >= 5 (unhinted)``".
#
# So how do we actually export this? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here's some basic options:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cut "actually"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"this" what? maybe "this code"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are some options:

@pianpwk pianpwk merged commit 37e0b1e into main Jan 23, 2025
20 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed torch.compile Torch compile and other relevant tutorials
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants